from emodel_generalisation.mcmc import load_chains
import seaborn as sns
from emodel_generalisation.mcmc import plot_corner
import pandas as pd
import numpy as np
import shap
from tqdm import tqdm
import matplotlib.pyplot as plt
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import RepeatedStratifiedKFold


def _get_shap_feature_importance(shap_values):
    """Average per-fold SHAP values and derive a global importance per feature.

    Args:
        shap_values: sequence with one SHAP array per CV fold; all entries must
            share the same shape so they can be averaged element-wise.

    Returns:
        A pair ``(mean_shap_values, shap_feature_importance)``:
        - ``mean_shap_values`` is the fold-averaged array. When that average has
          more than two dimensions, it is returned as a list split along its first
          axis instead.
        - ``shap_feature_importance`` is the mean absolute SHAP value along the
          second-to-last axis of the averaged values, after additionally
          averaging over the first axis in the >2-D case.
    """
    per_fold_mean = np.mean(shap_values, axis=0)

    if np.ndim(per_fold_mean) <= 2:
        return per_fold_mean, np.mean(np.abs(per_fold_mean), axis=0)

    # Higher-rank case: pool over the leading axis for the importance score, and
    # hand back the averaged values as a list of arrays along that axis.
    pooled = np.mean(per_fold_mean, axis=0)
    return list(per_fold_mean), np.mean(np.abs(pooled), axis=0)


def train(X, y, n_splits=5, n_repeats=1):
    """Cross-validate an XGBoost classifier and collect per-fold SHAP values.

    Args:
        X: feature table, indexed positionally with ``.iloc``.
        y: labels aligned with ``X``, also indexed with ``.iloc``.
        n_splits: number of stratified folds per repetition.
        n_repeats: number of times the fold split is repeated (seeded with 42).

    Returns:
        ``(acc_scores, shap_values)``: one validation accuracy and one SHAP
        array per fold, in fold order.
    """
    splitter = RepeatedStratifiedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=42)
    classifier = XGBClassifier()
    fold_accuracies, fold_shap_values = [], []

    for train_idx, val_idx in tqdm(splitter.split(X, y=y), total=n_splits * n_repeats):
        classifier.fit(X.iloc[train_idx], y.iloc[train_idx])
        predictions = classifier.predict(X.iloc[val_idx])
        fold_accuracies.append(accuracy_score(y.iloc[val_idx], predictions))

        # SHAP is evaluated on the full X rather than only the validation fold, so
        # every fold yields an array of the same shape that can be averaged later.
        fold_shap_values.append(shap.TreeExplainer(classifier).shap_values(X))

    return fold_accuracies, fold_shap_values


def plot(X, shap_values, name):
    """Save a bar and a dot SHAP summary plot for fold-averaged SHAP values.

    Prints the global feature importance next to ``X.columns``, then writes
    ``bar_shap_<name>.pdf`` and ``dot_shap_<name>.pdf``, each showing the top
    5 features.

    Args:
        X: feature table the SHAP values were computed on.
        shap_values: list of per-fold SHAP arrays, as returned by ``train``.
        name: suffix used in the output file names.
    """
    mean_shap_values, shap_feature_importance = _get_shap_feature_importance(shap_values)
    print(shap_feature_importance, X.columns)

    # (plot type, output file, extra summary_plot kwargs) for each figure.
    variants = (
        ("bar", f"bar_shap_{name}.pdf", {}),
        ("dot", f"dot_shap_{name}.pdf", {"color_bar_label": "parameter value"}),
    )
    for plot_type, filename, extra_kwargs in variants:
        shap.summary_plot(
            mean_shap_values,
            X,
            plot_type=plot_type,
            max_display=5,
            show=False,
            plot_size=(8, 3),
            **extra_kwargs,
        )
        plt.tight_layout()
        plt.savefig(filename)
        plt.close()


if __name__ == "__main__":
    mcmc_df = load_chains("../../mcmc_run/run_df.csv", base_path="../../mcmc_run")
    # Keep only samples below the cost threshold.
    mcmc_df = mcmc_df[mcmc_df.cost < 2.0]
    features = mcmc_df["features"]

    # Assign each sample to at most one firing type. The masks are mutually
    # exclusive by construction: tonic is decided first, runaway is only checked
    # on non-tonic samples, and the remainder is split into ecel / spp by burst count.
    mask_tonic = features["Step_ReboundBurst_burst.soma.v.tonic_after_burst"] > 0.0
    mask_runaway = (
        ~mask_tonic
        & (features["Step_ReboundBurst_burst.soma.v.burst_runaway"] < 0.05)
        & (features["Step_ReboundBurst_burst.soma.v.time_to_last_spike"] > 20000)
    )
    mask_ecel = (
        ~mask_tonic
        & ~mask_runaway
        & (features["Step_ReboundBurst_burst.soma.v.all_burst_number"] < 5)
    )
    mask_spp = (
        ~mask_tonic
        & ~mask_runaway
        & (features["Step_ReboundBurst_burst.soma.v.all_burst_number"] >= 5)
    )
    masks = {
        "runaway": mask_runaway,
        "spp": mask_spp,
        "ecel": mask_ecel,
        "tonic": mask_tonic,
    }

    params = mcmc_df["normalized_parameters"]

    # Long-format (variable, value, method) table of parameters per firing type.
    plot_dfs = []
    for method, mask in masks.items():
        method_df = params[mask].melt()
        method_df["method"] = method
        plot_dfs.append(method_df)
    plot_df = pd.concat(plot_dfs)

    order = [
        "g_pas.all",
        "gcabar_it2.basal",
        "shift_it2.somadend",
        "gkbar_iahp.basal",
        "gbar_ican.basal",
        "gnabar_hh2.somatic",
        "gkbar_hh2.somatic",
        "constant.distribution_increase",
        "gcabar_it2.somatic",
        "gkbar_iahp.somatic",
        "gk_max_iA.basal",
    ]
    hue_order = ["ecel", "spp", "runaway", "tonic"]

    plt.figure(figsize=(10, 4))
    ax = plt.gca()
    sns.boxplot(
        data=plot_df,
        x="variable",
        y="value",
        hue="method",
        whis=[10, 90],
        showfliers=False,
        ax=ax,
        hue_order=hue_order,
        palette=["C0", "C1", "C2", "C3"],
        order=order,
    )
    plt.xticks(rotation=45)
    ax.set_ylim(-1, 1)
    plt.tight_layout()
    plt.savefig("boxplot_params.pdf")
    plt.close()

    for method, mask in masks.items():
        plot_corner(mcmc_df[mask], feature=None, filename=f"corner_{method}.pdf")
        plt.close()

    # Binary classification tasks:
    # (output name, label printed with the score, training samples, boolean target).
    # The spp/ecel task is restricted to spp and ecel samples only; ``~mask_runaway``
    # would also pull in tonic samples and silently label them as ecel.
    spp_or_ecel = mask_spp | mask_ecel
    tasks = [
        ("spp_ecel", "spp/ecel", params[spp_or_ecel], mask_spp[spp_or_ecel]),
        ("runaway", "runaway", params, mask_runaway),
        ("tonic", "tonic", params, mask_tonic),
    ]
    for name, label, X, target in tasks:
        # drop=True: a plain reset_index() would insert the old row index as extra
        # column(s) in X, which the classifier would then train on as a feature.
        X = X.reset_index(drop=True)
        y = pd.Series(target.to_numpy().astype(int))

        acc_scores, shap_values = train(X, y)
        print(acc_scores)
        print(f"score of {np.mean(acc_scores)} +- {np.std(acc_scores)} for {label}")
        plot(X, shap_values, name=name)
